import numpy as np
import matplotlib
import matplotlib.pyplot as plt

def spike_preprocessing(unit_names1, unit_names2, spike1, spike2):
    """
    Align two spike datasets onto one shared, sorted set of unit columns.

    Every unit name appearing in either dataset gets exactly one column in
    the output, ordered by ``np.sort`` of the union of names. Columns for
    units that a dataset did not record are filled with zeros.

    unit_names1: unit names in the first dataset, one per column of each
        array in ``spike1``
    unit_names2: unit names in the second dataset, one per column of each
        array in ``spike2``
    spike1: a list of 2-D arrays, the spike counts for each trial in the
        first dataset, shaped (n_rows, len(unit_names1))
    spike2: a list of 2-D arrays, the spike counts for each trial in the
        second dataset, shaped (n_rows, len(unit_names2))

    Returns ``(spike1_, spike2_)``: lists of float64 arrays, each shaped
    (n_rows, N_unit) where N_unit is the number of distinct names overall.
    """
    all_unit_names = np.sort(list(set(unit_names1) | set(unit_names2)))
    n_unit = len(all_unit_names)
    # Build the name -> column lookup once. Calling list.index() per name
    # re-scans the whole list each time, which is quadratic in unit count.
    col_of = {name: col for col, name in enumerate(all_unit_names)}

    def _expand(unit_names, spikes):
        # Scatter each trial's columns into a zero-filled array over all units.
        cols = [col_of[name] for name in unit_names]
        expanded = []
        for trial in spikes:
            out = np.zeros((trial.shape[0], n_unit))
            out[:, cols] = trial
            expanded.append(out)
        return expanded

    return _expand(unit_names1, spike1), _expand(unit_names2, spike2)

def get_pred_cursor_pos(model, day_spike_format, day_cursor_format, domain_flag=False):
    """
    Run ``model`` on one day's data and return only its first output.

    ``model`` is called as ``model(spikes, cursor, domain_flag=...)`` and must
    return exactly three values. The first is returned as the predicted
    cursor position; the other two are discarded.
    """
    (prediction, _second, _third) = model(
        day_spike_format,
        day_cursor_format,
        domain_flag=domain_flag,
    )
    return prediction

def plot_cursor_position_fig(pred_cursor_pos, gt_spike, gt_target_dir, window_size, fig_para, save_fig_file):
    """
    Draw predicted cursor trajectories, one line per trial, colored by target.

    pred_cursor_pos: 2-D array of predictions for all trials concatenated in
        trial order; column 0 is plotted on x and column 1 on y.
    gt_spike: sequence of per-trial arrays; only ``shape[0]`` is used, to
        work out how many prediction rows belong to each trial.
    gt_target_dir: per-trial target direction, mapped to one of 8 "Dark2"
        colors via ``int(dir / 45 + 3)``. Presumably degrees in multiples of
        45 -- TODO confirm against the data loader.
    window_size: decoder window length; a trial with T rows contributes
        ``T - window_size + 1`` prediction rows (0 if the trial is shorter).
    fig_para: dict providing 'line_width' and 'f_size' (figure size).
    save_fig_file: output path passed to ``plt.savefig``.

    A new figure is created and left open after saving.
    """
    line_width = fig_para['line_width']
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # keeps the same (name, lut) signature on old and new releases.
    colors = plt.get_cmap("Dark2", 8).colors
    plt.figure(figsize=fig_para['f_size'])
    # Per-dataset axis presets used before (not applied here):
    #   Jango:  xlim(-4, 4),  ylim(-5, 4)
    #   Mihili: xlim(-6, 13), ylim(-42, -21)
    cur_idx = 0
    for ti, trial in enumerate(gt_spike):
        # Clamp at 0 so a too-short trial cannot move cur_idx backwards and
        # misalign every following trial.
        trial_len = max(trial.shape[0] - window_size + 1, 0)
        # A direction of -180 gives index -1, which numpy indexing wraps to
        # the last color (the same one 180 maps to).
        target_dir_idx = int(gt_target_dir[ti] / 45 + 3)
        segment = pred_cursor_pos[cur_idx:cur_idx + trial_len]
        plt.plot(segment[:, 0], segment[:, 1], color=colors[target_dir_idx, :], linewidth=line_width)
        cur_idx += trial_len
    plt.savefig(save_fig_file)
    return

def plot_cursor_position(pred_cursor_pos, gt_spike, gt_target_dir, window_size, save_fig_file,
                         *, xlim=(-10, 10), ylim=(-10, 10)):
    """
    Draw predicted cursor trajectories on a fixed, axis-less 6x6 figure.

    pred_cursor_pos: 2-D array of predictions for all trials concatenated in
        trial order; column 0 is plotted on x and column 1 on y.
    gt_spike: sequence of per-trial arrays; only ``shape[0]`` is used, to
        work out how many prediction rows belong to each trial.
    gt_target_dir: per-trial target direction, mapped to one of 8 "Dark2"
        colors via ``int(dir / 45 + 3)``. Presumably degrees in multiples of
        45 -- TODO confirm against the data loader.
    window_size: decoder window length; a trial with T rows contributes
        ``T - window_size + 1`` prediction rows (0 if the trial is shorter).
    save_fig_file: output path passed to ``plt.savefig``.
    xlim, ylim: axis limits, defaulting to the previously hard-coded
        (-10, 10). An earlier Jango preset was xlim (-4, 4), ylim (-5, 4).

    A new figure is created and left open after saving.
    """
    line_width = 3
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap
    # keeps the same (name, lut) signature on old and new releases.
    colors = plt.get_cmap("Dark2", 8).colors
    plt.figure(figsize=(6, 6))
    plt.cla()
    cur_idx = 0
    for ti, trial in enumerate(gt_spike):
        # Clamp at 0 so a too-short trial cannot move cur_idx backwards and
        # misalign every following trial.
        trial_len = max(trial.shape[0] - window_size + 1, 0)
        target_dir_idx = int(gt_target_dir[ti] / 45 + 3)
        segment = pred_cursor_pos[cur_idx:cur_idx + trial_len]
        plt.plot(segment[:, 0], segment[:, 1], color=colors[target_dir_idx, :], linewidth=line_width)
        cur_idx += trial_len
    # These were re-applied on every loop iteration; once, after plotting,
    # yields the same final figure.
    plt.xlim(*xlim)
    plt.ylim(*ylim)
    plt.axis('off')
    plt.savefig(save_fig_file)
    return

def plot_r2_score(avg_list, std_list, fig_name, fig_para, save_fig_file):
    """
    Plot one R^2-vs-session curve with error bars for each method.

    avg_list: list of 1-D arrays of mean R^2 per session, all the same
        length. The x positions are 0 .. len(avg_list[0]) - 1.
    std_list: list of matching error-bar magnitudes, indexed like avg_list.
    fig_name: figure title.
    fig_para: dict providing 'cmap_name', 'f_size', 'ms_size', 'er_width',
        'alpha', tick ranges 'x_l'/'x_h'/'x_inter' and 'y_l'/'y_h'/'y_inter'
        (half-open, as in ``np.arange``), and 'legend_flag' plus
        'legend_title'/'legend_font_size'/'legend_loc'.
    save_fig_file: output path passed to ``plt.savefig``.

    The first curve is always 'royalblue'. Curve i > 0 takes entry i of the
    colormap resampled to 5 colors; past index 4 the colormap's "over" color
    is used.
    """
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9. Calling the
    # colormap with an int index works for any colormap type, unlike .colors.
    cmap = plt.get_cmap(fig_para['cmap_name'], 5)
    x = list(range(avg_list[0].shape[0]))
    plt.figure(figsize=fig_para['f_size'])
    for i, avg in enumerate(avg_list):
        line_color = 'royalblue' if i == 0 else cmap(i)
        plt.errorbar(x, avg, yerr=std_list[i], fmt='o-', ecolor=line_color, color=line_color,
                     ms=fig_para['ms_size'], elinewidth=fig_para['er_width'], alpha=fig_para['alpha'])
    plt.xticks(np.arange(fig_para['x_l'], fig_para['x_h'], fig_para['x_inter']))
    plt.yticks(np.arange(fig_para['y_l'], fig_para['y_h'], fig_para['y_inter']))
    plt.xlabel('Session ID')
    plt.ylabel(r'$R^{2}$')
    if fig_para['legend_flag']:
        plt.legend(fig_para['legend_title'], fontsize=fig_para['legend_font_size'], loc=fig_para['legend_loc'])
    plt.title(fig_name)
    plt.savefig(save_fig_file, bbox_inches='tight')
    return

def plot_avg_curve(avg_arr, fig_name, fig_para, save_fig_file):
    """
    Plot each row of ``avg_arr`` as an R^2-vs-session line with markers.

    avg_arr: 2-D array; row i is one curve and column j is plotted at x = j + 1.
    fig_name: figure title.
    fig_para: dict providing 'cmap_name', 'f_size', 'ms_size', tick ranges
        'x_l'/'x_h'/'x_inter' and 'y_l'/'y_h'/'y_inter' (half-open, as in
        ``np.arange``), and 'legend_title'/'legend_font_size'/'legend_loc'.
    save_fig_file: output path passed to ``plt.savefig``.

    Row i is colored with entry i of the colormap resampled to 5 colors; past
    index 4 the colormap's "over" color is used.
    """
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9. Calling the
    # colormap with an int index works for any colormap type, unlike .colors.
    cmap = plt.get_cmap(fig_para['cmap_name'], 5)
    x = list(range(1, avg_arr.shape[1] + 1))
    plt.figure(figsize=fig_para['f_size'])
    for i in range(avg_arr.shape[0]):
        plt.plot(x, avg_arr[i, :], marker='o', color=cmap(i), ms=fig_para['ms_size'])
    plt.xticks(np.arange(fig_para['x_l'], fig_para['x_h'], fig_para['x_inter']))
    plt.yticks(np.arange(fig_para['y_l'], fig_para['y_h'], fig_para['y_inter']))
    plt.legend(fig_para['legend_title'], fontsize=fig_para['legend_font_size'], loc=fig_para['legend_loc'])
    plt.title(fig_name)
    plt.xlabel('Session ID')
    plt.ylabel(r'$R^{2}$')
    plt.savefig(save_fig_file, bbox_inches='tight')
    return

from sklearn.manifold import TSNE
def plot_latent_tsne(latent_z_src, latent_z_tgt, latent_para, fig_name, fig_para, save_fig_file, legend_flag=True):
    """
    Embed a slice of source and target latents jointly with t-SNE and scatter them.

    latent_z_src, latent_z_tgt: 2-D arrays of latent vectors, one row per sample.
    latent_para: dict with 'src_start'/'src_num' and 'tgt_start'/'tgt_num',
        selecting the rows ``[start:start + num]`` from each set.
    fig_name: figure title.
    fig_para: dict providing 'f_size', 'src_c', 'tgt_c', 'alpha', 's_size',
        'legend_font_size', and 'legend_title' (used when legend_flag is true).
    save_fig_file: output path; the figure is always written as SVG.
    legend_flag: whether to draw the legend.

    t-SNE uses a fixed ``random_state=500`` so repeated calls produce the
    same layout for the same input.
    """
    src_start, src_num = latent_para['src_start'], latent_para['src_num']
    tgt_start, tgt_num = latent_para['tgt_start'], latent_para['tgt_num']

    src_part = latent_z_src[src_start:src_start + src_num, :]
    tgt_part = latent_z_tgt[tgt_start:tgt_start + tgt_num, :]
    # Slicing silently truncates when fewer than src_num rows remain. Split
    # the embedding on the rows actually taken; splitting on src_num would
    # color some target points as source.
    n_src = src_part.shape[0]
    latent_z = np.concatenate((src_part, tgt_part), axis=0)

    t_sne = TSNE(n_components=2, random_state=500)
    latent_z_tsne = t_sne.fit_transform(latent_z)
    src_xy, tgt_xy = latent_z_tsne[:n_src], latent_z_tsne[n_src:]

    plt.figure(figsize=fig_para['f_size'])
    plt.scatter(src_xy[:, 0], src_xy[:, 1], color=fig_para['src_c'], alpha=fig_para['alpha'],
                s=fig_para['s_size'], edgecolor=fig_para['src_c'])
    plt.scatter(tgt_xy[:, 0], tgt_xy[:, 1], color=fig_para['tgt_c'], alpha=fig_para['alpha'],
                s=fig_para['s_size'], edgecolor=fig_para['tgt_c'])

    plt.title(fig_name, fontsize=fig_para['legend_font_size'])
    if legend_flag:
        plt.legend(fig_para['legend_title'], fontsize=fig_para['legend_font_size'])
    plt.xticks([])
    plt.yticks([])
    plt.savefig(save_fig_file, format='svg', bbox_inches='tight')
    return

def plot_hyper_r2_score(avg_list, std_list, fig_name, fig_para, save_fig_file):
    """
    Plot R^2 (with error bars) against a hyperparameter's values, one curve per entry.

    avg_list: list of mean-R^2 sequences, each aligned with fig_para['x_list'].
    std_list: list of matching error-bar magnitudes, indexed like avg_list.
    fig_name: figure title.
    fig_para: dict providing 'cmap_name', 'f_size', 'x_list', 'ms_size',
        'er_width', 'alpha', 'plot_line_width', 'tick_font_size', y-tick
        range 'y_l'/'y_h'/'y_inter' (half-open, as in ``np.arange``),
        'hyper_name' (x label), and 'legend_flag' plus
        'legend_title'/'legend_font_size'/'legend_loc'.
    save_fig_file: output path passed to ``plt.savefig``.

    Curve i is colored with entry i of the colormap resampled to 5 colors;
    past index 4 the colormap's "over" color is used.
    """
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9. Calling the
    # colormap with an int index works for any colormap type, unlike .colors.
    cmap = plt.get_cmap(fig_para['cmap_name'], 5)

    plt.figure(figsize=fig_para['f_size'])
    for i, avg in enumerate(avg_list):
        line_color = cmap(i)
        plt.errorbar(fig_para['x_list'], avg, yerr=std_list[i], fmt='o-', ecolor=line_color, color=line_color,
                     ms=fig_para['ms_size'], elinewidth=fig_para['er_width'], alpha=fig_para['alpha'],
                     linewidth=fig_para['plot_line_width'])
    # X tick positions are left to matplotlib; only their font size is set.
    plt.xticks(size=fig_para['tick_font_size'])
    my_y_ticks = np.arange(fig_para['y_l'], fig_para['y_h'], fig_para['y_inter'])
    plt.yticks(my_y_ticks, size=fig_para['tick_font_size'])

    plt.title(fig_name)
    plt.xlabel(fig_para['hyper_name'], fontdict={'size': fig_para['tick_font_size']})
    plt.ylabel(r'$R^{2}$', fontdict={'size': fig_para['tick_font_size']})
    if fig_para['legend_flag']:
        plt.legend(fig_para['legend_title'], fontsize=fig_para['legend_font_size'], loc=fig_para['legend_loc'])

    plt.savefig(save_fig_file, bbox_inches='tight')
    return